import numpy as np
import gymnasium as gym
import re, copy


def strip_instance(name):
    '''Return ``name`` with any trailing instance number removed.

    Multi-instanced objects carry a numeric suffix (e.g. "Block0", "Block12");
    this maps them back to their class name ("Block"). Names without trailing
    digits are returned unchanged.
    '''
    instance_suffix = re.compile(r'\d+$')
    return instance_suffix.sub('', name)

# Factor names that are not part of the object state: convert_env_rl skips them when
# building state spaces, and get_factor_graph drops their rows (except "Goal") unless
# complete_graph is requested.
non_state_factors = ["Action", "Reward", "Done", "Goal"]

class Environment(gym.Env):
    '''Abstract base class for factored-state environments.

    Subclasses fill in the spaces and object bookkeeping initialized in
    __init__, and implement step, reset, get_state and
    set_from_factored_state. The trace helpers read each object's
    ``interaction_trace`` through ``self.object_name_dict``.
    '''

    def __init__(self, frameskip = 1, horizon=200, variant="", fixed_limits=False, renderable=False, render_masks=False):
        ''' Initialize default values for every attribute a subclass must provide.

        Args:
            frameskip: stored as self.frameskip.
            horizon: stored as self.max_steps; a value <= 0 means effectively
                unbounded (1e12).
            variant: accepted for subclass compatibility; unused by this base class.
            fixed_limits: stored as self.fixed_limits (normalization limits that
                are fixed across all objects).
            renderable, render_masks: accepted for subclass compatibility; unused
                by this base class.

        Required attributes (set by the subclass):
            num_actions: int or None
            action_space: gym.Spaces
            action_shape: tuple of ints
            observation_space: gym.Spaces
            done: Done factor (its .attribute holds the termination flag)
            reward: Reward factor (its .attribute holds the reward value)
            seed_counter: int
            discrete_actions: boolean
            name: string
        '''
        # environment properties
        self.self_reset = True
        self.num_actions = None # this must be defined, -1 for continuous. Only needed for primitive actions
        self.name = "ABSTRACT_BASE" # required for an environment
        self.fixed_limits = fixed_limits # honor the constructor argument: normalization limits fixed across all objects
        self.discrete_actions = True
        self.frameskip = frameskip # number of frames advanced per step
        self.transpose = True # transposes the visual domain

        # spaces
        self.action_shape = (1,) # should be set in the environment, (1,) is for discrete action environments
        self.action_space = None # gym.spaces
        self.observation_space = None # raw space, gym.spaces

        # state components
        self.frame = None # the image generated by the environment
        self.reward = Reward()
        self.done = Done()
        self.action = np.zeros(self.action_shape)
        self.extracted_state = None

        # running values
        self.itr = 0
        self.total_itr = 0
        self.max_steps = horizon if horizon > 0 else 1e12

        # factorized state properties
        self.all_names = [] # must be initialized, the names of all the objects including multi-instanced ones
        self.valid_names = list() # must be initialized, the names of all objects used in a particular trajectory. The environment should treat as nonexistent objects which are not part of this list
        self.num_objects = -1 # should be defined if valid, the number of objects (or virtual objects) in the flattened obs
        self.object_names = [] # must be initialized, a list of names that controls the ordering of things
        self.object_sizes = dict() # must be initialized, a dictionary of name to length of the state
        self.object_range = dict() # the minimum and maximum values for a given feature of an object
        self.object_dynamics = dict() # the most that an object can change in a single time step
        self.object_range_true = dict() # if using a fixed range, this stores the true range (for sampling)
        self.object_dynamics_true = dict() # if using a fixed dynamics range, this stores the true range (for sampling)
        self.object_instanced = dict() # name of object to max number of objects of that type
        self.object_proximal = dict() # name of object to whether that object has valid proximity
        self.object_name_dict = dict() # the string names to object classes
        self.instance_length = 0 # the total number of instances for the mask

        # proximity state components
        self.position_masks = dict()
        # size of the position vector (dimensionality), should be set by the subclass.
        # Assigned once here; the previous duplicate assignment of 2 was always overwritten by 0.
        self.pos_size = 0
        self.goal_based = False
        self.goal_idx = -1 # if a specific object, use self.all_names.index(<target object name>)
        self.goal_graph_idx = -1 # if not negative, then usually self.goal_idx - 1
        self.goal_space = None # if there is a goal space, this should be set

        # trace components # passive trace defaults to 1,0=1
        self.passive_trace = np.eye(2) # should be set in subclass
        self.passive_trace[1,0] = 1

    def step(self, action):
        '''
        self.save_path is the path to which to save files, and self.itr is the iteration number to be used for saving.
        The format of saving is: folders contain the raw state, names are numbers, contain 2000 raw states each
        obj_dumps contains the factored state
        empty string for save_path means no saving state
        matches the API of OpenAI gym by taking in action (and optional params)
        returns
            state as dict: next raw_state (image or observation) next factor_state (dictionary of name of object to tuple of object bounding box and object property)
            reward: the true reward from the environment
            done flag: if an episode ends, done is True
            info: a dict with additional info
        '''
        pass

    def reset(self):
        '''
        matches the API of OpenAI gym, resetting the environment
        returns:
            state as dict: next raw_state, next factor_state (dict with corresponding keys)
        '''
        pass

    def render(self, mode='human'):
        '''
        matches the API of OpenAI gym, rendering the environment
        returns None for human mode
        '''

    def close(self):
        '''
        closes and performs cleanup
        '''

    def seed(self, seed):
        '''Seed numpy's global RNG.

        A negative seed (or None) is replaced by a random seed in [0, 10000).
        numpy should be the only source of randomness; override if there are more.
        '''
        if seed is None or seed < 0:
            seed = np.random.randint(10000)
        np.random.seed(seed)

    def get_state(self):
        '''
        Returns the current state as a dictionary with keys:
            raw_state (dictionary of name of object to raw state)
            factor_state (dictionary of name of object to tuple of object bounding box and object property)
        '''
        pass

    def get_info(self): # returns the info, the most important value is TimeLimit.truncated, can be overriden
        return {"TimeLimit.truncated": False}

    def get_itr(self):
        return self.itr

    def run(self, policy, iterations = 10000):
        '''Roll out ``policy`` for up to ``iterations`` steps.

        policy.act(full_state) is called with the latest state; an action equal
        to -1 stops the rollout. self.itr is updated to the loop index.
        '''
        full_state = self.get_state()
        for self.itr in range(iterations):
            action = policy.act(full_state)
            # Only a single-valued action can be the quit signal. Truth-testing
            # `action == -1` directly raises for multi-element (continuous) actions.
            if np.size(action) == 1 and np.asarray(action).item() == -1:
                break
            full_state = self.step(action)

    def flatten_factored_state(self, factored_state):
        '''Concatenate the per-object states in self.all_names order along the last axis.'''
        return np.concatenate([np.array(factored_state[n]) for n in self.all_names], axis=-1)

    def set_from_factored_state(self, factored_state, valid_names):
        '''
        from the factored state, sets the environment.
        If the factored state is not complete, then this function should do as good a reconstruction as possible
        might not be implemented for every environment
        '''
        pass

    def _valid_names_from_state(self, factored_state, all_names):
        '''Return the names flagged valid by factored_state["VALID_NAMES"], or all_names if absent.

        VALID_NAMES is treated as a vector aligned with all_names; nonzero entries are valid.
        '''
        if "VALID_NAMES" not in factored_state:
            return all_names
        return [all_names[i] for i, flag in enumerate(factored_state["VALID_NAMES"]) if flag != 0]

    def current_trace(self, names=None):
        '''Return interaction indicators read from the objects' interaction_trace lists.

        With ``names`` (an object exposing .target and .primary_parent), returns a
        list with one 0/1 entry per target instance: 1 if the primary parent's name
        is in that target's interaction_trace. Without ``names``, returns a
        flattened array of len(all_names)**2 entries: for each target (row), 1 where
        each name in all_names is in the target's interaction_trace.
        '''
        if names is not None:
            target_obj = self.object_name_dict[names.target]
            # a multi-instanced target is stored as a list of objects
            targets = target_obj if isinstance(target_obj, list) else [target_obj]
            parent_name = self.object_name_dict[names.primary_parent].name
            return [int(parent_name in target.interaction_trace) for target in targets]
        targets = [self.object_name_dict[n] for n in self.all_names]
        traces = [[int(parent in target.interaction_trace) for parent in self.all_names]
                  for target in targets]
        return np.array(traces).flatten()

    def get_trace(self, factored_state, action, names=None):
        '''Restore ``factored_state``, take ``action``, and return current_trace(names).

        Valid names are taken from factored_state["VALID_NAMES"] when present (as
        in get_full_trace), so set_from_factored_state receives its required
        valid_names argument.
        '''
        valid = self._valid_names_from_state(factored_state, self.all_names)
        self.set_from_factored_state(factored_state, valid_names=valid)
        self.step(action)
        return self.current_trace(names = names)

    def get_full_current_trace(self, all_names=""):
        '''Return a dict mapping each target name to a 0/1 array over the interaction names.

        Interaction names are all_names (default self.all_names) minus Reward and
        Done. An entry is 1 if that name is in the target's interaction_trace, or
        if it is the target itself (self-interaction).
        '''
        # len() rather than truthiness so numpy arrays of names are accepted
        if len(all_names) == 0: all_names = self.all_names
        inter_names = [n for n in all_names if n not in {"Reward", "Done"}]
        traces = dict()
        for target in all_names:
            interacted = self.object_name_dict[target].interaction_trace
            traces[target] = np.array([int(val in interacted or val == target) for val in inter_names])
        return traces

    def get_factor_graph(self, all_names="", complete_graph=False):
        '''Return the current interaction graph as a stacked 2-D 0/1 array.

        Rows: every name when complete_graph, otherwise only names not in
        non_state_factors plus "Goal". Columns: every name when complete_graph,
        otherwise all names except Reward and Done. Self-interactions are 1.
        Raises ValueError (from np.stack) if no rows are selected.
        '''
        if len(all_names) == 0: all_names = self.all_names
        if complete_graph:
            inter_names = all_names
        else:
            inter_names = [n for n in all_names if n not in {"Reward", "Done"}]
        rows = list()
        for target in all_names:
            if complete_graph or target not in non_state_factors or target == "Goal":
                interacted = self.object_name_dict[target].interaction_trace
                rows.append(np.array([int(val in interacted or val == target) for val in inter_names]))
        return np.stack(rows, axis=0)

    def get_full_trace(self, factored_state, action, outcome_variable="", all_names=""):
        '''Restore ``factored_state``, take ``action``, and return get_full_current_trace(all_names).

        Only names flagged in factored_state["VALID_NAMES"] (if present) are passed
        as valid_names. ``outcome_variable`` is unused by this base class.
        '''
        if len(all_names) == 0: all_names = self.all_names
        valid = self._valid_names_from_state(factored_state, all_names)
        self.set_from_factored_state(factored_state, valid_names = valid)
        self.step(action)
        # pass all_names through so a caller-supplied name list is actually used
        return self.get_full_current_trace(all_names=all_names)

    def demonstrate(self):
        '''
        gives an image and gets a keystroke action
        '''
        return 0

    def toString(self, extracted_state):
        '''
        converts an extracted state into a string for printing. Note this might be overriden since self.objects is not a guaranteed attribute
        Every field is emitted as "NAME:values" followed by a tab.
        '''
        parts = ["ITR:" + str(self.itr)]
        # TODO: attributes are limited to single floats
        parts += [obj.name + ":" + " ".join(map(str, extracted_state[obj.name])) for obj in self.objects]
        if "VALID_NAMES" in extracted_state: # TODO: stores valid names in the factored state for now
            parts.append("VALID_NAMES:" + " ".join(map(str, extracted_state['VALID_NAMES'])))
        if "TRACE" in extracted_state:
            parts.append("TRACE:" + " ".join(map(str, np.array(extracted_state['TRACE']).flatten())))
        return "".join(part + "\t" for part in parts)

    def valid_binary(self, valid_names):
        '''Return a 0/1 array over self.all_names marking which names are in valid_names.'''
        return np.array([(1 if n in valid_names else 0) for n in self.all_names])

    def name_indices(self, names):
        '''Return the position of each name in self.all_names.

        Raises ValueError if a name is not in self.all_names. (The previous
        implementation called list.find, which does not exist, and always failed.)
        '''
        all_names = list(self.all_names)
        return [all_names.index(n) for n in names]

    def sample_reward(self):
        return None

    def set_goal_params(self, goal_params):
        '''Set the goal radius (goal_params["radius"]) on self.goal and self.goal_epsilon, if a goal exists.'''
        if hasattr(self, "goal" ) and self.goal is not None:
            self.goal.set_goal_epsilon(goal_params["radius"])
            self.goal_epsilon = goal_params["radius"]


def convert_env_rl(env):
    '''Attach flat RL observation spaces to ``env`` and return it.

    For every name in env.all_names that is not in non_state_factors, a Box is
    built from env.object_range[<class name>] (low = entry 0, high = entry 1),
    where the class name is the name with its instance suffix stripped.
    Sets on env:
        num_factors: env.num_objects - 3
        factor_spaces: list of per-factor Box spaces, in all_names order
        dict_obs_space: gym Dict space keyed by factor name
        breakpoints: np.array of cumulative feature offsets, starting at 0
        observation_space: one Box concatenating all factor bounds
    '''
    # NOTE(review): subtracts 3 for Action/Reward/Done, but "Goal" is also excluded
    # from factor_spaces below; confirm num_objects semantics in goal-based envs.
    env.num_factors = env.num_objects - 3
    env.factor_spaces = []
    env.dict_obs_space = {}
    env.breakpoints = [0]
    state_names = (n for n in env.all_names if n not in non_state_factors)
    for factor_name in state_names:
        bounds = env.object_range[strip_instance(factor_name)]
        lower, upper = bounds[0], bounds[1]
        env.factor_spaces.append(gym.spaces.Box(low=lower, high=upper))
        env.dict_obs_space[factor_name] = gym.spaces.Box(low=lower, high=upper)
        env.breakpoints.append(env.breakpoints[-1] + len(lower))
    env.dict_obs_space = gym.spaces.Dict(env.dict_obs_space)
    env.breakpoints = np.array(env.breakpoints)
    all_lows = np.concatenate([space.low for space in env.factor_spaces])
    all_highs = np.concatenate([space.high for space in env.factor_spaces])
    env.observation_space = gym.spaces.Box(low=all_lows, high=all_highs)
    return env

class EnvObject():
    '''Base class for a single factor (object) of the environment state.

    Attributes:
        name: identifier of the object, used as its key in factored states.
        interaction_trace: names of objects this object interacted with.
        state: value stored by set_state and returned by get_state.
        attribute: subclass-specific value; unused by this base class.
    '''
    def __init__(self, name):
        self.name = name
        self.interaction_trace = list()
        self.state = None
        self.attribute = None

    def set_state(self, state=None):
        '''Store ``state`` when it is not None, then return the current state.

        Previously the argument was ignored, so get_state never reflected it;
        this now matches Goal.set_state.
        '''
        if state is not None:
            self.state = state
        return self.state

    def get_state(self):
        '''Return the stored state (None until set).'''
        return self.state

    def set_interaction_from_trace(self, trace, all_names):
        '''Set interaction_trace to the names in all_names whose trace entry is nonzero.

        ``trace`` is an array-like aligned with all_names (lists are accepted).
        '''
        trace = np.asarray(trace)
        self.interaction_trace = np.array(all_names)[trace.nonzero()[0]].tolist()

class Action(EnvObject):
    '''Factor holding the most recent action passed to take_action.

    Continuous actions are stored as-is (initially a zero vector of length
    ``size``); discrete actions are stored as a scalar (initially 0).
    '''
    def __init__(self, continuous, size):
        super().__init__("Action")
        self.continuous = continuous
        if continuous:
            self.attribute = np.zeros(size)
        else:
            self.attribute = 0
        self.interaction_trace = []

    def take_action(self, action):
        self.attribute = action

    def get_state(self):
        '''Return the action as an array; a discrete action is wrapped in a 1-element array.'''
        if not self.continuous:
            return np.array([self.attribute])
        return np.array(self.attribute)

class Done(EnvObject):
    '''Factor recording the episode termination flag (initially False).'''
    def __init__(self):
        super().__init__("Done")
        self.interaction_trace = []
        self.attribute = False

    def get_state(self):
        '''Return the flag as a 1-element array.'''
        flag = self.attribute
        return np.array([flag])

    def interact(self, other):
        '''Record that ``other`` (by its name) interacted with this factor.'''
        trace = self.interaction_trace
        trace.append(other.name)

class Reward(EnvObject):
    '''Factor recording the current reward value (initially 0.0).'''
    def __init__(self):
        super().__init__("Reward")
        self.interaction_trace = []
        self.attribute = 0.0

    def get_state(self):
        '''Return the reward as a 1-element array.'''
        value = self.attribute
        return np.array([value])

    def interact(self, other):
        '''Record that ``other`` (by its name) interacted with this factor.'''
        trace = self.interaction_trace
        trace.append(other.name)

class Goal(EnvObject):
    '''Base goal factor.

    Subclasses are expected to set bounds, all_names, target_idx, partial and
    goal_epsilon; this base class only provides placeholders for most of them.
    '''
    def __init__(self, **kwargs):
        super().__init__("Goal")
        self.attribute = np.zeros(1) # wrong dimensions until sample_goal is called
        self.interaction_trace = list()
        self.target_idx = -3 # should be the index of the target index, SET IN SUBCLASS
        self.partial = 1 # should be the length of the indices to use as goals, SET IN SUBCLASS
        self.goal_epsilon = -1 # should be set in subclass

    def generate_bounds(self):
        '''Return (first ``partial`` rows of self.bounds, 0/1 mask of length len(self.bounds)).

        The mask is 1 for the first ``partial`` entries and 0 for the rest.
        self.bounds must be set by the subclass.
        '''
        goal_bounds = self.bounds[:self.partial]
        mask = np.array([1] * self.partial + [0] * (self.bounds.shape[0] - self.partial))
        return goal_bounds, mask

    def sample_goal(self, reset_state):
        '''Placeholder: returns one uniform float in [0, 1) and ignores reset_state.'''
        return np.random.rand()

    def get_achieved_goal(self, env):
        '''Stack every object's state (zero-padded to a common length) and extract the achieved goal.

        Objects are looked up by name in env.object_name_dict, in the order of
        self.all_names. NOTE(review): self.all_names is never set by this base
        class — presumably assigned by subclasses.
        '''
        # fetch each state once (previously get_state() was called three times per object)
        states = [np.asarray(env.object_name_dict[n].get_state()) for n in self.all_names]
        longest = max(len(s) for s in states)
        padded = [np.pad(s, (0, longest - s.shape[0])) for s in states]
        return self.get_achieved_goal_state(np.stack(padded, axis=0))

    def get_achieved_goal_state(self, object_state, fidx=None):
        '''Select object row ``target_idx`` and its first ``partial`` features (leading axes preserved).'''
        return object_state[...,self.target_idx,:self.partial]

    def add_interaction(self, reached_goal):
        '''Append "Target" to interaction_trace when the goal was reached.'''
        if reached_goal:
            self.interaction_trace += ["Target"]

    def get_state(self):
        return self.attribute # np.array([self.goal_epsilon])

    def set_state(self, goal=None):
        '''Store ``goal`` when it is not None, then return the current goal.'''
        if goal is not None: self.attribute = goal
        return self.attribute # np.array([self.goal_epsilon])

    def set_goal_epsilon(self, goal_epsilon):
        self.goal_epsilon = goal_epsilon

    def check_goal(self, env):
        '''Return True if, in every dimension, the SQUARED difference between the
        achieved goal and self.attribute is below goal_epsilon.'''
        return np.all(np.square(self.get_achieved_goal(env) - self.attribute) < self.goal_epsilon)

